{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp interpret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.data.all import *\n",
    "from fastai.optimizer import *\n",
    "from fastai.learner import *\n",
    "import sklearn.metrics as skm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interpretation of Predictions\n",
    "\n",
    "> Classes to build objects to better interpret predictions of a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def plot_top_losses(x, y, *args, **kwargs):\n",
    "    raise Exception(f\"plot_top_losses is not implemented for {type(x)},{type(y)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_all_ = [\"plot_top_losses\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Interpretation():\n",
    "    \"Interpretation base class, can be inherited for task specific Interpretation classes\"\n",
    "    def __init__(self, dl, inputs, preds, targs, decoded, losses):\n",
    "        store_attr(\"dl,inputs,preds,targs,decoded,losses\")\n",
    "\n",
    "    @classmethod\n",
    "    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):\n",
    "        \"Construct interpretation object from a learner\"\n",
    "        if dl is None: dl = learn.dls[ds_idx]\n",
    "        return cls(dl, *learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=None))\n",
    "\n",
    "    def top_losses(self, k=None, largest=True):\n",
    "        \"`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`).\"\n",
    "        return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)\n",
    "\n",
    "    def plot_top_losses(self, k, largest=True, **kwargs):\n",
    "        losses,idx = self.top_losses(k, largest)\n",
    "        if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)\n",
    "        if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)\n",
    "        else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))\n",
    "        b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))\n",
    "        x,y,its = self.dl._pre_show_batch(b, max_n=k)\n",
    "        b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))\n",
    "        x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)\n",
    "        if its is not None:\n",
    "            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses,  **kwargs)\n",
    "        #TODO: figure out if this is needed\n",
    "        #its None means that a batch knows how to show itself as a whole, so we pass x, x1\n",
    "        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = synth_learner()\n",
    "interp = Interpretation.from_learner(learn)\n",
    "x,y = learn.dls.valid_ds.tensors\n",
    "test_eq(interp.inputs, x)\n",
    "test_eq(interp.targs, y)\n",
    "out = learn.model.a * x + learn.model.b\n",
    "test_eq(interp.preds, out)\n",
    "test_eq(interp.losses, (out-y)[:,0]**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ClassificationInterpretation(Interpretation):\n",
    "    \"Interpretation methods for classification models.\"\n",
    "\n",
    "    def __init__(self, dl, inputs, preds, targs, decoded, losses):\n",
    "        super().__init__(dl, inputs, preds, targs, decoded, losses)\n",
    "        self.vocab = self.dl.vocab\n",
    "        if is_listy(self.vocab): self.vocab = self.vocab[-1]\n",
    "\n",
    "    def confusion_matrix(self):\n",
    "        \"Confusion matrix as an `np.ndarray`.\"\n",
    "        x = torch.arange(0, len(self.vocab))\n",
    "        d,t = flatten_check(self.decoded, self.targs)\n",
    "        cm = ((d==x[:,None]) & (t==x[:,None,None])).long().sum(2)\n",
    "        return to_np(cm)\n",
    "\n",
    "    def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', cmap=\"Blues\", norm_dec=2,\n",
    "                              plot_txt=True, **kwargs):\n",
    "        \"Plot the confusion matrix, with `title` and using `cmap`.\"\n",
    "        # This function is mainly copied from the sklearn docs\n",
    "        cm = self.confusion_matrix()\n",
    "        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "        fig = plt.figure(**kwargs)\n",
    "        plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "        plt.title(title)\n",
    "        tick_marks = np.arange(len(self.vocab))\n",
    "        plt.xticks(tick_marks, self.vocab, rotation=90)\n",
    "        plt.yticks(tick_marks, self.vocab, rotation=0)\n",
    "\n",
    "        if plot_txt:\n",
    "            thresh = cm.max() / 2.\n",
    "            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
    "                coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'\n",
    "                plt.text(j, i, coeff, horizontalalignment=\"center\", verticalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n",
    "\n",
    "        ax = fig.gca()\n",
    "        ax.set_ylim(len(self.vocab)-.5,-.5)\n",
    "\n",
    "        plt.tight_layout()\n",
    "        plt.ylabel('Actual')\n",
    "        plt.xlabel('Predicted')\n",
    "        plt.grid(False)\n",
    "\n",
    "    def most_confused(self, min_val=1):\n",
    "        \"Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences.\"\n",
    "        cm = self.confusion_matrix()\n",
    "        np.fill_diagonal(cm, 0)\n",
    "        res = [(self.vocab[i],self.vocab[j],cm[i,j])\n",
    "                for i,j in zip(*np.where(cm>=min_val))]\n",
    "        return sorted(res, key=itemgetter(2), reverse=True)\n",
    "\n",
    "    def print_classification_report(self):\n",
    "        \"Print scikit-learn classification report\"\n",
    "        d,t = flatten_check(self.decoded, self.targs)\n",
    "        print(skm.classification_report(t, d, labels=list(self.vocab.o2i.values()), target_names=[str(v) for v in self.vocab]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
